-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
supplement the function of slice. #34172
Conversation
Thanks for your contribution! |
e017ec5
to
f41ce1c
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@@ -154,6 +160,13 @@ class StridedSliceOp : public framework::OperatorWithKernel { | |||
protected: | |||
framework::OpKernelType GetExpectedKernelType( | |||
const framework::ExecutionContext &ctx) const override { | |||
auto *in_var = ctx.InputVar("Input"); | |||
auto is_in_var_array = in_var->IsType<framework::LoDTensorArray>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
结合下面的code,会不会有这种情况,lodtensorarray里面tensor的place是cuda_pinned
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thx.
TensorCopy(in_tensor, context.GetPlace(), out_tensor); | ||
} | ||
|
||
return; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
是不是用else分支管理代码好一些
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thx.
// StridedSliceGrad | ||
// cannot be calculated by `framework::GradVarName("Output")`, | ||
// the dim of "Input" is used to calculate the output shape. | ||
// when set it to inplace OP, there may be some problems. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里注释说的可能存在的问题是已解决的还是TODO的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个注释说的是可能存在的问题:因为这个反向op使用Input(前向op的输入)计算输出shape,所以这个op不能是inplace op。
改成了NOTE(xx):
set_zero(dev_ctx, d_out_tensor, static_cast<T>(0)); | ||
} | ||
} | ||
return; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thx.
@@ -176,6 +177,45 @@ def test_set_value_with_save(self): | |||
output_spec=None) | |||
|
|||
|
|||
class TestSliceSupplementCase(unittest.TestCase): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名建议准确一些,带一些功能的特征?后面可能还会追加case吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
修改了名字:TestSliceSupplementCase -> TestSliceSupplementSpecialCase
添加了一行注释:# unittest for slice index which abs(step)>0. eg: x[::2]
|
||
self.create_case(Net(input_size=112, array_size=13)) | ||
|
||
# TODO(weixin):Currently, the case that the start index is |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这种case现在的报错提示或者说warning是怎样的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
strided_slice_op.h的130行处理index,现在可支持这种情况了,但是用到这个op的其他api例如,varbase.getitem、paddle.strided_slice等op也做了类似的简单处理,这些处理是不冲突的。
if (ends[axis_index] < 0) { | ||
ends[axis_index] = 0; | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[start : end : step]:处理end<-axis_size的情况。例如:len(a)=10, a[:-100:-1]
d3aaabd
to
3b6ad37
Compare
platform::is_same_place(tensor.place(), | ||
ctx.device_context().GetPlace()), | ||
true, platform::errors::InvalidArgument( | ||
"Place of context is %s. Place of context is %s. They " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
有一个place是tensor的?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thx.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Function optimization
PR changes
APIs
Describe
a[::2]